{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 神经网络实验"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 第 0 步：多层感知机（Multi Layer Perceptron, MLP）算法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "class myMLP3():\n",
    "    \"\"\"\n",
    "        This is a simple implementation for 3-layer neural network (i.e., input layer, hidden layer and output layer)\n",
    "    \"\"\"\n",
    "    def __init__(self, input_layer_size=64, hidden_layer_size=100, output_layer_size=10,\n",
    "                learning_rate=1, epochs=2000):\n",
    "        \"\"\"\n",
    "            MLP3 initilization\n",
    "        \"\"\"\n",
    "        self.input_layer_size = input_layer_size\n",
    "        self.hidden_layer_size = hidden_layer_size\n",
    "        self.output_layer_size = output_layer_size\n",
    "        self.learning_rate = learning_rate\n",
    "        self.epochs = epochs\n",
    "\n",
    "        # 模型参数\n",
    "        self.params = { 'W1': np.random.randn(self.hidden_layer_size, self.input_layer_size),\n",
    "           'b1': np.zeros((self.hidden_layer_size, 1)),\n",
    "           'W2': np.random.randn(self.output_layer_size, self.hidden_layer_size),\n",
    "           'b2': np.zeros((self.output_layer_size, 1))}\n",
    "\n",
    "        # 缓存变量：保留前向计算函数产生的汇聚值（非线性变换单元输入）和激励值（非线性变换单元输出）\n",
    "        self.cache = {}\n",
    "        \n",
    "        # 梯度变量：保留反向传播函数计算得到的模型参数W1,b1,W2,b2的更新梯度\n",
    "        self.grads = {}\n",
    "\n",
    "    def sigmoid(self, z):\n",
    "        \"\"\"\n",
    "            sigmoid function for nonlinear activation\n",
    "        \"\"\"\n",
    "        a = 1 / (1 + np.exp(-z))\n",
    "        return a\n",
    "\n",
    "    def softmax(self, z):\n",
    "        \"\"\"\n",
    "            softmax function for output\n",
    "        \"\"\"\n",
    "        a = np.exp(z) / np.sum(np.exp(z), axis=0)\n",
    "        return a\n",
    "\n",
    "    def compute_multiclass_loss(self, Y, Y_hat):\n",
    "        \"\"\"\n",
    "            loss function for multiclass classification\n",
    "        \"\"\"\n",
    "        # 交叉熵损失\n",
    "        L_sum = np.sum(np.multiply(Y, np.log(Y_hat)))\n",
    "        sample_num = Y.shape[1]\n",
    "        L = -(1/sample_num) * L_sum\n",
    "        return L\n",
    "\n",
    "    def feed_forward(self, X):\n",
    "        \"\"\"\n",
    "            forward computation\n",
    "        \"\"\"\n",
    "        # 前向计算：计算非线性截点的汇聚值（非线性变换单元输入）和激励值（非线性变换单元输出）\n",
    "        self.cache['Z1'] = np.matmul(self.params[\"W1\"], X) + self.params[\"b1\"]\n",
    "        self.cache['A1'] = self.sigmoid(self.cache['Z1'])                                \n",
    "        self.cache['Z2'] = np.matmul(self.params[\"W2\"], self.cache[\"A1\"]) + self.params[\"b2\"]                     \n",
    "        self.cache['A2'] = self.softmax(self.cache['Z2'])\n",
    "\n",
    "    def back_propagate(self, X, Y):\n",
    "        \"\"\"\n",
    "            backward propagation\n",
    "        \"\"\"\n",
    "        # 反向传播：计算模型参数W1,b1,W2,b2的更新梯度\n",
    "        sample_num = X.shape[1]\n",
    "\n",
    "        dZ2 = self.cache[\"A2\"] - Y\n",
    "        dW2 = (1./sample_num) * np.matmul(dZ2, self.cache[\"A1\"].T)\n",
    "        db2 = (1./sample_num) * np.sum(dZ2, axis=1, keepdims=True)\n",
    "        \n",
    "        dA1 = np.matmul(self.params[\"W2\"].T, dZ2)\n",
    "        dZ1 = dA1 * self.sigmoid(self.cache[\"Z1\"]) * (1 - self.sigmoid(self.cache[\"Z1\"]))\n",
    "        dW1 = (1./sample_num) * np.matmul(dZ1, X.T)\n",
    "        db1 = (1./sample_num) * np.sum(dZ1, axis=1, keepdims=True)\n",
    "\n",
    "        self.grads = {\"dW1\": dW1, \"db1\": db1, \"dW2\": dW2, \"db2\": db2}\n",
    "        return self.grads\n",
    "\n",
    "    def train(self, X, Y):\n",
    "        \"\"\"\n",
    "            MLP3 training\n",
    "        \"\"\"\n",
    "        for i in range(epochs):\n",
    "            self.feed_forward(X)\n",
    "            self.back_propagate(X, Y)\n",
    "\n",
    "            # 参数更新：梯度下降法\n",
    "            self.params['W2'] = self.params['W2'] - self.learning_rate * self.grads['dW2']\n",
    "            self.params['b2'] = self.params['b2'] - self.learning_rate * self.grads['db2']\n",
    "            self.params['W1'] = self.params['W1'] - self.learning_rate * self.grads['dW1']\n",
    "            self.params['b1'] = self.params['b1'] - self.learning_rate * self.grads['db1']\n",
    "\n",
    "            if (i % 20 == 0):\n",
    "                print(\"Epoch\", i, \"cost: \", self.compute_multiclass_loss(Y, self.cache[\"A2\"]))\n",
    "            elif (i==epochs-1):\n",
    "                print(\"Epoch\", epochs, \"cost: \", self.compute_multiclass_loss(Y, self.cache[\"A2\"]))\n",
    "        #return self.params\n",
    "\n",
    "    def predict(self, X):\n",
    "        \"\"\"\n",
    "            MLP3 prediction\n",
    "        \"\"\"\n",
    "        self.feed_forward(X)\n",
    "        return np.argmax(self.cache[\"A2\"], axis=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 第 1 步：准备实验数据 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.datasets import load_digits\n",
    "from sklearn.metrics import classification_report, confusion_matrix\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import minmax_scale\n",
    "from sklearn.preprocessing import OneHotEncoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "digits_dataset = load_digits()\n",
    "\n",
    "X = minmax_scale(digits_dataset.data) # min-max normalize each feature dimension\n",
    "ohe = OneHotEncoder()\n",
    "y = np.array(ohe.fit_transform(digits_dataset.target.reshape(-1,1)).todense()) # convert the class labels to one-hot vectors\n",
    "\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=1)\n",
    " \n",
    "np.random.seed(1)\n",
    "shuffle_index = np.random.permutation(X_train.shape[0]) # shuffle the order of the training set \n",
    "X_train = X_train[shuffle_index, :]\n",
    "y_train = y_train[shuffle_index, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAuEAAACgCAYAAACi27JyAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAck0lEQVR4nO3df0xX1/nA8YcNzMT6A60oigJKtR/BCbjNdjMqnXXOmGqFLjqbiNK4NUsGbUz9wyxIsljXxAE6s5gmU1dTWZpUYVZd6g/U2TCD/MiMVtsKKtjZlYi2Ykchn+8fC3wZhfscvJdzL/p+JU3WnZNzH87nnvt5Cp/P80SEw2EBAAAAYM+3/A4AAAAAeNSQhAMAAACWkYQDAAAAlpGEAwAAAJaRhAMAAACWkYQDAAAAlkX2Z/Ljjz8eTkxMdHXBW7duqXMaGxsdx4cMGaKuMXXqVHVOdHS0Okdz/vz5z8Ph8FgRb/bHREdHh+N4fX29usbw4cPVOePGjTOOqS9+7M/Nmzcdxz/99NMBj0FEZPLkyeqc69evd+2PiDd79MUXX6hzWlpaXK8xYcIEdc6oUaPUOZr+3kPa6//tb39bvWZzc7Pj+P3799U1TEybNs1x3OScen3GtOeLiL7H2v6JiDz22GPqHO0eM3mGd98fEXvPIc2NGzfUOSb7mJaW5joWr+8hk+fHlStXHMdNYhgzZoxpSK54vT8NDQ3qnKFDhzqOmzwbvMhxTAzEGQvScyg5OVmdo+m5R536lYQnJiZKVVWVq0CKi4vVOa+88orjeFxcnLrG22+/rc7x4uEVERFxrfN/e7E/JrQEKicnR11jwYIF6pz8/HzDiPrmx/5s3rzZcbywsHDAYxAR2bRpkzrnF7/4xbXu/+7FHlVUVKhzDh486HoNk31ctmyZOkfT33tIe/1N/sNgz549juN1dXXqGiZ27drlOG5yTr0+Y9rzRUTfY23/RMx+Nu06Js/w7vsjYu85pDF5vprsoxc/i9f3kMnzIzMz03G8oKBAXcPkvc4LXu+PSdzavW1yfrzIcUwMxBkL0nNIe7800XOPOvFxFAAAAMAyknAAAADAMpJwAAAAwDKScAAAAMAyknAAAADAMpJwAAAAwDKScAAAAMCyftUJ94JJkfqEhATHcZPalya1H03qUAaRVmO0rKxMXWPt2rUeRWOXSd1PbY5Jzc+amhp1jlYne+7cueoaD0K7b7X6uybmz5+vztm9e7c6x4s64f1VW1vrOG5yPrRnUF5enrqGSR3oIDSN6Wn58uXqHO0ZbPKcN6mVrL2Wtuog92Ty3qG9/nv37rUSixcNs/pLe91MxMTEeBCJP7TXxOQ9SDtDJj1XtDraIvZqrfeXSVzas9yk1rwXOYWbPeQ34QAAAIBlJOEAAACAZSThAAAAgGUk4QAAAIBlJOEAAACAZSThAAAAgGUk4QAAAIBl1uuEm9St1Oocm9QeNqmNqtWY9KPGsUlt4YqKCsfxkSNHqmvU19ebhhQoJvV5tZqeJjXkb9++rc7RamnPmDFDXeNBaPe2yX2r1VdOT09X1zCpJ+0HrUaxSY1eP85+UJicMa2+uckem9Rr96sOuBabyb2vPYe1WvQiIteuXVPn+FEHXGPy+mvPT5NeBUF16tQpx3GT+1p7nzepxW9yn/pVJ1yL3+T5sGbNGsdxk3zT5HnnRd37vvCbcAAAAMAyknAAAADAMpJwAAAAwDKScAAAAMAyknAAAADAMpJwAAAAwDKScAAAAMAyknAAAADAMuvNekwaC3jRKOPOnTvqnLa2NtfX6a933nnHcbykpERdo6amxnHcpNGKX00w3DJpZuQFreGPiMjPf/7zgQ/kAZg0I/KioY3WKMEvY8aMcRzXzo/Io92sp7i4WJ2jnQ+TRiImTcXmzp2rzhkI8+bNcxwvKChQ19AahZg0Sdm7d686J4i0ZjUi+h4GsQmRqWHDhg34
NbSGWSJmZ/l3v/ud4/jGjRtNQ+oXk/Ov8aLRkMl99tVXX7m+Tl/4TTgAAABgGUk4AAAAYBlJOAAAAGAZSTgAAABgGUk4AAAAYBlJOAAAAGAZSTgAAABgmfU64V7QahybeuyxxzxZpz92797tOD5r1ix1DZMa1hqT+prPPfec4/j27dtdxzEQli9f7jheVlbmyXWio6Mdx9evX+/JdfrLpAb8K6+84jhu6z4cCKFQyHG8oqLCTiCDlEmdeZM5Xqxx6dIlx/GFCxe6jqM3MTExjuNaDXATLS0trtfwy5EjR1yvob2Pm9SaN6mV7Yfy8nLHce35a1NTU5Mv162srPTlug+itbV1wNbmN+EAAACAZSThAAAAgGUk4QAAAIBlJOEAAACAZSThAAAAgGUk4QAAAIBlJOEAAACAZSThAAAAgGWBbNajNTHIz89X10hISFDnPP3008YxeWXLli2O4yYNLEpKSlzHMWrUKHVORkaG6+v4wYsmGAUFBeocLxp2DAST8zFy5EjHcZP7I6h++9vfOo5rzZxERCIiIhzH8/Ly1DWKi4vVOX7Q7luT197kHtOYNF0LUlMTr5nss3ZO/dLW1uZ6Da1pmklTNZOmYlpzroF41v34xz92HK+pqVHXWLZsmes47ty5o865e/eu6+s8CC/yL+3nM8kFTJ5DA/lez2/CAQAAAMtIwgEAAADLSMIBAAAAy0jCAQAAAMtIwgEAAADLSMIBAAAAy0jCAQAAAMtIwgEAAADLrDfrMSnAX1RU5Dh+6tQpT2JJS0tzHD9//rwn1+nPNb1oNKM1GhER2bNnjzpHizWocnJyHMdN7p+gNuIxkZiYqM7RmkIdPHhQXcPkLHvRcKK/tOYbJve+1tDHpGGWdh+K+HPGtEY7Js2MtAYXJj+7SSMRk6Zrg1VDQ4M6x2SPtNdiIO4xkyY5Gu21NXkvrKurU+doz/uBeEZpa5o0+9LOocn+mFzHr/c67Tk9f/58dQ0tVzxw4IDrOEQG9jnNb8IBAAAAy0jCAQAAAMtIwgEAAADLSMIBAAAAy1wn4Tdu3JDMzEwJhUKSkpJi9IWlR1FHR4ekp6fL0qVL/Q4lcC5fvixpaWld/4wYMUKKi4v9DitQioqKJCUlRVJTU2XVqlXy1Vdf+R1SoCQmJsrMmTMlLS1NMjMz/Q4ncEpKSiQ1NVVSUlLkj3/8o9/hBM66deskNjZWUlNT/Q4lsFpaWiQ7O1uefPJJWbhwoVRXV/sdUqAcPXpUpk+fLsnJybJ161a/wwkk9uibXCfhkZGRsm3bNrl06ZJUVlbKzp075eLFi17E9lApKSmRUCjkdxiBNH36dKmtrZXa2lo5f/68REdHy/PPP+93WIHR1NQk27dvl6qqKrlw4YJ0dHRIaWmp32EFzsmTJ6W2tlZOnjzpdyiBcuHCBXnzzTfl3LlzUldXJ3/729/kk08+8TusQMnJyZGjR4/6HUag5eXlyeLFi+XDDz+Uw4cPS3Jyst8hBUZHR4f86le/kiNHjsjFixdl//795EE9sEe9c52Ex8XFSUZGhoiIDB8+XEKhkDQ1NbkO7GHS2Ngo7733nrz00kt+hxJ4x48fl6lTpz7UpckeRHt7u9y/f1/a29ultbVVJkyY4HdIGCQuXbokTz31lERHR0tkZKT86Ec/kkOHDvkdVqDMmzdPRo8e7XcYgXX37l05ffq05ObmiojIkCFDZMSIET5HFRznzp2T5ORkmTJligwZMkRWrlxpVML1UcIe9c7TOuENDQ1SU1Mjc+bM6XNOTU2Nuo5W19MkQdPqIIvo9VVv3LihrmEiPz9f3njjDfniiy88WU9jsj8mtTH9UFpaKqtWrXK1hlYH2qR2apBMnDhRNmzYIJMnT5ahQ4fKokWLZNGiRX3ON6nxrdVxvnbtmrqGyVm2VSc8IiJCFi1aJBERETJ16lR5+umn+5xr8pE57ec3qZMclDr7qampsmnTJmlubpahQ4fKiRMn5Hvf+16fz4CKigp1
Te1nM/lIUEFBgevrDGYm71Fe9cRw6+rVqzJ27FhZu3at1NXVyezZs6WkpESGDRvW63yTZ9CaNWscx03eo0xqXNt4BjU1NcmkSZO6/j0+Pl7+8Y9/9Dlfq9UvIpKenu46Lm2PRcxq+nuhv3tkcg9p+6jleCJmfSMGkmdfzPzyyy8lKytLiouL+S/kbg4dOiSxsbEye/Zsv0MJvLa2NikvL5cXXnjB71AC5fbt21JWVib19fVy8+ZNuXfvnuzbt8/vsALl7NmzUl1dLUeOHJG///3vfNyim1AoJBs3bpRnn31WFi9eLLNmzZLISOt92jCItbe3S3V1tbz88stSU1Mjw4YN4zO93YTD4W/8fyZN8x4l7FHvPEnCv/76a8nKypLVq1fLihUrvFjyoXH27FkpLy+XxMREWblypZw4cUJefPFFv8MKpCNHjkhGRoaMGzfO71AC5dixY5KUlCRjx46VqKgoWbFihXzwwQd+hxUonR/PiY2NlZkzZ8r169d9jihYcnNzpbq6Wk6fPi2jR4+WJ554wu+QMIjEx8dLfHx811+5s7Oz+WJmN/Hx8f/zl/PGxkY+MtgDe9Q710l4OByW3NxcCYVC8uqrr3oR00Pl9ddfl8bGRmloaJDS0lJ55pln+C1mH/bv3+/6oygPo8mTJ0tlZaW0trZKOByW48eP8yXfbu7du9f1Ua979+7JlStXZPz48T5HFSyfffaZiIhcv35d3n33Xc4Z+mX8+PEyadIkuXz5soj897s7M2bM8Dmq4Pj+978vH330kdTX10tbW5uUlpbKc88953dYgcIe9c713yTPnj0rb731Vld5MBGRLVu2yJIlS1wHh0dHa2urvP/++7Jr1y6/QwmcOXPmSHZ2tmRkZEhkZKSkp6fL+vXr/Q4rMG7dutVVTae9vV1CoRD/kdJDVlaWNDc3S1RUlOzcuVNiYmL8DilQVq1aJRUVFfL5559LfHy8FBYWdn0JEf+1Y8cOWb16tbS1tcmUKVNk9+7dfocUGJGRkfKHP/xBfvKTn0hHR4esW7dOUlJS/A4rUNij3rlOwufOndvrZ33wTQsWLDD6Ms6jKDo6Wpqbm/0OI7AKCwulsLDQ7zACacqUKVJXV9f170VFRT5GE0xnzpzxO4RA279/v98hBF5aWppUVVX5HUZgLVmyhF8+Ktijb6JjJgAAAGAZSTgAAABgGUk4AAAAYFlEfz7PHRER8W8R0bt4PFoSwuHwWBH2pw/sj7Ou/RFhj/rAPeSM/XHGGdNxDzljf5xxxnT/s0ed+pWEAwAAAHCPj6MAAAAAlpGEAwAAAJb1q074448/Hk5MTBygUP7fzZs3Hce9qic9ffp0x/EhQ4aoa5w/f/7zzs/52NqftrY2x/Fbt26pa3R20HMyZswYx3GTn9Xr/fn444/VOR0dHY7j2utuU/f9EfFmj0zOR/f2wb0xufdN4oyOjlbnaLy+h7TzIyJSX1/v6hoi9u6z/u5PZ3fRvpg8P+7cueM4PnToUHWN5ORkdY7JfagZiDNmcg91dpfsi1d75AU/3se057TJffjpp5+qc2bOnOk4HtT3+YsXLzqODx8+XF1j0qRJXoXjyK8z9s9//tNxPC4uTl1jwoQJxjG50XOPOvUrCU9MTLRSrH/z5s2O43v27PHkOuXl5Y7jJjdRRERE15cPbO1PQ0OD43hxcbG6RklJiTpn6dKljuMmr4PX+7N8+XJ1TktLi+N4RUWFqxi81H1/RLzZI5PXJT8/33Hc5N43uU5nF103vL6HtPMjIpKTk+PqGiL27rP+7o8Wl8nzo6yszHF82rRp6hoHDx5U53iR7AzEGTO5h7TGbCZnw2SPvODH+5j2nDa5D00amA3W93nt/jBp/Geyh17w64wlJSU5jpt0ltbyTa/03KNOfBwFAAAAsIwkHAAAALCMJBwAAACwjCQcAAAAsIwkHAAAALCMJBwAAACwjCQcAAAAsKxfdcJt0ep2mtTHNKmvqtVKtlWjtTuTup5a
XUvt5xIRKSoqUucE8efX6hObMFlj2bJlrq/jF5N7aNSoUY7jJjVaTe6zINVk72RyX586dcpxfOTIkeoaJvujnWXtdXoQ2v1h8prNnz/fcVzbPxFvntF+MaktfO1ar2WBu9TW1noUTfCYPGMLCgocx+vq6jyJxUZjnf4yOWPaM9iLXgZB5kU/mCC+9j3xm3AAAADAMpJwAAAAwDKScAAAAMAyknAAAADAMpJwAAAAwDKScAAAAMAyknAAAADAskDWCdfqX5rUxzSpDzkQNXg1LS0tjuMm9We1GqNpaWnqGsuXL1fn+CEpKclxfNasWeoaWn3ZmpoadY3BXCdcu8dE9HvEpE64SZ1jrV6wH/tscu4TEhJcr2FS51Z7lpmc5f7S+iyYPBu0uE2evyb3qV+02Pbu3auukZeX5zjux/uPV7R7xKROuFZrX6tFL2JWjz6ITHo5aGfM5PlrcsZM+q74wWSP1qxZ4zg+GGqp85twAAAAwDKScAAAAMAyknAAAADAMpJwAAAAwDKScAAAAMAyknAAAADAMpJwAAAAwDKScAAAAMAyz5v1XL161XH8z3/+s7qGVoTepAD9tWvX1Dnjx49X53ht69atjuNeNAHRmvmI6A1tREQ2btyozvGa1rzFpIFDTEyMV+EMSibNVkpKShzHtUYaIiL5+fnqnCA2PTLZn7Vr1zqOmzxfTp48qc4ZiGY8GpPXTaM9g032x6Shj1+qqqpcr5GZmelBJMH0m9/8xnHc5L7W7kOTZi1BbdbzzjvvOI6bNDPSmsqZNKIxedZpDfAG6v302LFjjuN37txR1zBpbhh0/CYcAAAAsIwkHAAAALCMJBwAAACwjCQcAAAAsIwkHAAAALCMJBwAAACwjCQcAAAAsMzzOuFabc8dO3Z4fckH9p3vfMf6NV977TXHcS9ql2t1xEXMaoya1OS2zeRn06Snp7sPxEcff/yx47hWA9yESY1Wk3soiLy4hxISEtQ5CxYscH2doNJqPK9Zs0ZdI8j3z7/+9S/XaxQVFbkaFzHrG/GnP/3JcXz06NHqGv01e/ZsV+M2nTt3znH8Bz/4gefX/NnPfuZ6De05ZdLrw2SO9n6yb98+dY0HUV5e7nqNpKQk12uYPMu1mubJyckPfH1+Ew4AAABYRhIOAAAAWEYSDgAAAFhGEg4AAABYRhIOAAAAWEYSDgAAAFhGEg4AAABYRhIOAAAAWOZ5s57t27e7GjexfPlydU5FRYU6R2s4MRC0xglexFRWVuZ6DRGRn/70p56s46WGhgZ1zsiRIx3Hg9iEqD+8aJ5QX1/vOG7SaGbz5s3qHC8a4/RXS0uL47hJ3Hl5eY7jJg2RTM7hsmXL1Dm2aQ3XRET27t3rOH7y5El1jdraWuOY+pKWluZ6jd7Ex8e7XkN7VpmcMZP3sXXr1jmOHzx4UF3jYdba2mr9mtp7zKlTp9Q1vGi6ZvJ82bBhg+vrPAgvmkhpDa9Mng+ZmZnqHO091+Q9pS/8JhwAAACwjCQcAAAAsIwkHAAAALCMJBwAAACwjCQcAAAAsIwkHAAAALCMJBwAAACwjCQcAAAAsMzzZj1e0BrWmDTB0Iq4i4iMGjXKOKbB5MCBA+ock0YRQWTSwEBr1mNSWD8xMVGdozWNMlnjQaSnp7tew6Qhi0ZriuMXLS6T18WkSYrG5Bz60axHa5LjpvFEJ5OGaia059RANaLR7hHtGSOix27yDNaaIpmuMxh59f6s3e8DsX/a88Pk+aKdQ5PGdUFu1KTleYWFheoa2h54dQ8N5BnjN+EAAACAZSThAAAAgGUk4QAAAIBlJOEAAACAZSThAAAAgGWuk/DLly9LWlpa1z8jRozwpPLCw6SkpERSU1MlJSWFvenhxo0bkpmZKaFQSFJSUoyqnzyKioqKJCUlRVJTU2Xbtm3S1tbmd0iB0n1/fv3rX8t//vMfv0MKjJ5n7O233/Y7pEBqaWmR7OxsefLJJ2XhwoVSXV3td0iBsm7dOomNjZXU
1FS/Qwmsjo4OSU9Pl6VLl/odSiB1P2Nz5syRc+fO+R2S71wn4dOnT5fa2lqpra2V8+fPS3R0tDz//PNexPZQuHDhgrz55pty7tw5qaurk0OHDslHH33kd1iBERkZKdu2bZNLly5JZWWl7Ny5Uy5evOh3WIHS1NQk27dvl6qqKrlw4YJ0dHTImTNn/A4rMHrbn7/+9a9+hxUYPc/YX/7yF/nkk0/8Ditw8vLyZPHixfLhhx/K4cOHJTk52e+QAiUnJ0eOHj3qdxiBVlJSIqFQyO8wAqv7GTtz5oxMnz7d75B852md8OPHj8vUqVMlISHB1TpaXc9Zs2apa+Tk5LiKwSuXLl2Sp556SqKjo0VEZP78+XLgwAF57bXXBuyaJjVI9+zZM2DX74+4uDiJi4sTEZHhw4dLKBSSpqYmmTFjRq/z58+fr66p1Q412R+T+qJJSUmO417WCW9vb5f79+9LVFSUxMTEyNKlS2XRokW9zi0oKFDX015/k58/SH/V6b4/ERER8t3vfrfP/Td5/bXnh8n+BOWXET3PWEJCgjQ1NcnEiRN7nW9y32o/f1pamrpGkO6fu3fvyunTp7vOxbRp0xznm9xDWm1hkxrga9asUefY2sd58+YZ1ab2ipYHmLLVz6CxsVHee+892bRpk/z+9793nOtF3enBVh++5xmLjY11nF9TU6OuqdUaN/nLuklOMWjqhJeWlsqqVau8XHLQS01NldOnT0tzc7O0trbK4cOH5caNG36HFUgNDQ1SU1Mjc+bM8TuUQJk4caJs2LBBJk+eLHFxcTJy5Mg+E/BHEftjrqGhQa5cuSIpKSl+hxIoV69elbFjx8ratWslPT1dXnrpJbl3757fYWEQyc/PlzfeeEO+9S2+atcbzljvPLtb2trapLy8XF544QWvlnwohEIh2bhxozz77LOyePFimTVrlkRGBrJRqa++/PJLycrKkuLiYhkxYoTf4QTK7du3paysTOrr6+XmzZty79492bdvn99hBQb7Y6bzjOXn58uwYcP8DidQ2tvbpbq6Wl5++WWpqamRYcOGydatW/0OC4PEoUOHJDY2VmbPnu13KIHFGeudZ0n4kSNHJCMjQ8aNG+fVkg+N3Nxcqa6ultOnT8vo0aPliSee8DukQPn6668lKytLVq9eLStWrPA7nMA5duyYJCUlydixYyUqKkpWrFghH3zwgd9hBQb7o+t+xgbbn7FtiI+Pl/j4+K6/wmVnZ/PFTBg7e/aslJeXS2JioqxcuVJOnDghL774ot9hBQpnrHeeJeH79+/noyh9+Oyzz0RE5Pr16/Luu++yT92Ew2HJzc2VUCgkr776qt/hBNLkyZOlsrJSWltbJRwOy/Hjx/nyTzfsjzPOmG78+PEyadIkuXz5soj89/tNfX0vBejp9ddfl8bGRmloaJDS0lJ55pln+GtcD5yx3nnyuYjW1lZ5//33ZdeuXV4s99DJysqS5uZmiYqKkp07d0pMTIzfIQXG2bNn5a233pKZM2d2fZlry5YtsmTJEp8jC445c+ZIdna2ZGRkSGRkpKSnp8v69ev9Disw2B9nPc9Ya2ur/PKXv5Qf/vCHfocWKDt27JDVq1dLW1ubTJkyRXbv3u13SIGyatUqqaiokM8//1zi4+OlsLBQcnNz/Q4Lgwhn7Js8ScKjo6OlubnZi6UeSpST69vcuXMlHA77HUbgFRYWSmFhod9hBBb707eeZ6yystLHaIIrLS1Nqqqq/A4jsPbv3+93CIPCggUL+MhXHzhj38TXeAEAAADLSMIBAAAAyyL681GAiIiIf4vItYELZ1BKCIfDY0XYnz6wP8669keEPeoD95Az9scZZ0zHPeSM/XHGGdP9zx516lcSDgAAAMA9Po4CAAAAWEYSDgAAAFhGEg4AAABYRhIOAAAAWEYSDgAAAFhGEg4AAABYRhIOAAAAWEYSDgAAAFhGEg4AAABY9n9L3BVSWnbhrwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 720x144 with 20 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Plot the first 20 training images in a 2 x 10 grid\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    " \n",
    "fig = plt.figure(figsize=(10, 2))\n",
    "fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)\n",
    "\n",
    "for i in range(20):\n",
    "    ax = fig.add_subplot(2, 10, i + 1, xticks=[], yticks=[])\n",
    "    # each sample is a flat 64-value row; reshape it back to an 8x8 image\n",
    "    ax.imshow(X_train[i,:].reshape(8,8), cmap=plt.cm.binary, interpolation='nearest')\n",
    "    # annotate the image with its class label (index of the 1 in its one-hot row)\n",
    "    ax.text(0, 7, '{}'.format(np.argwhere(y_train[i,:]==1)[0][0]))\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 第 2 步：使用反向传播（Back propagation, BP）算法训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0 cost:  14.688652865179417\n",
      "Epoch 20 cost:  0.8245247058208665\n",
      "Epoch 40 cost:  0.40402247012741915\n",
      "Epoch 60 cost:  0.26273335749492155\n",
      "Epoch 80 cost:  0.20847926641516004\n",
      "Epoch 100 cost:  0.17900463964161958\n",
      "Epoch 120 cost:  0.15806566509070102\n",
      "Epoch 140 cost:  0.14216942865720078\n",
      "Epoch 160 cost:  0.1295482583554408\n",
      "Epoch 180 cost:  0.11919695227463499\n",
      "Epoch 200 cost:  0.11090018825265255\n"
     ]
    }
   ],
   "source": [
    "# Experiment parameters: try different hidden-layer sizes, learning rates and iteration counts\n",
    "hidden_layer_size = 100 # number of hidden-layer units\n",
    "learning_rate = 1 # learning rate\n",
    "epochs = 200 # number of BP training iterations\n",
    "\n",
    "# Layer sizes are dictated by the data: 64 inputs (8*8 image pixels) and 10 outputs (classes 0,1,...,9)\n",
    "input_layer_size, output_layer_size = X_train.shape[1], y_train.shape[1]\n",
    "\n",
    "# Build the model, then train it; the model expects one sample per column, hence the transposes\n",
    "mlp3 = myMLP3(\n",
    "    input_layer_size=input_layer_size,\n",
    "    hidden_layer_size=hidden_layer_size,\n",
    "    output_layer_size=output_layer_size,\n",
    "    learning_rate=learning_rate,\n",
    "    epochs=epochs,\n",
    ")\n",
    "mlp3.train(X_train.T, y_train.T)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 第 3 步：模型测试并输出测试结果"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "confusion matrix: \n",
      " [[20  0  0  0  0  0  0  0  0  0]\n",
      " [ 0 18  0  0  0  0  1  0  0  0]\n",
      " [ 0  0 21  0  0  0  0  0  0  0]\n",
      " [ 0  0  0 22  0  0  0  0  0  0]\n",
      " [ 0  0  0  0 18  0  0  0  0  0]\n",
      " [ 0  0  0  0  0 17  0  0  0  0]\n",
      " [ 0  0  0  0  0  0 16  0  0  0]\n",
      " [ 0  0  0  1  0  0  0 20  0  0]\n",
      " [ 0  0  0  0  0  1  0  0 14  0]\n",
      " [ 0  0  0  0  0  1  0  0  0 10]]\n",
      "\n",
      "classification report: \n",
      "               precision    recall  f1-score   support\n",
      "\n",
      "           0       1.00      1.00      1.00        20\n",
      "           1       1.00      0.95      0.97        19\n",
      "           2       1.00      1.00      1.00        21\n",
      "           3       0.96      1.00      0.98        22\n",
      "           4       1.00      1.00      1.00        18\n",
      "           5       0.89      1.00      0.94        17\n",
      "           6       0.94      1.00      0.97        16\n",
      "           7       1.00      0.95      0.98        21\n",
      "           8       1.00      0.93      0.97        15\n",
      "           9       1.00      0.91      0.95        11\n",
      "\n",
      "    accuracy                           0.98       180\n",
      "   macro avg       0.98      0.97      0.98       180\n",
      "weighted avg       0.98      0.98      0.98       180\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the trained model on the held-out test set\n",
    "predictions = mlp3.predict(X_test.T) # predicted class indices\n",
    "labels = y_test.argmax(axis=1) # ground-truth class indices recovered from the one-hot rows\n",
    "\n",
    "# Confusion matrix first, then the per-class classification report\n",
    "print('confusion matrix: \\n', confusion_matrix(labels, predictions))\n",
    "print('\\nclassification report: \\n', classification_report(labels, predictions))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 第 4 步：自选测试样本，用学习得到的MLP进行分类"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "selected image category is 0\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAD4CAYAAAA0L6C7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAKyElEQVR4nO3dX4hc5RnH8d+vq9L6J6hNKJINXRckIIWauAQkIDR2S6yivaiSQIVKwZsqSgtGe9c7vRF7UQSJWsFEyUYFEasVVFqhte7GtDWuliSmZGtsEhrxT6Uh+vRiJxDt2j0zc857zj79fiC4szvs+wzJ1zMze/a8jggByONLbQ8AoF5EDSRD1EAyRA0kQ9RAMqc18U2XL18eY2NjTXzrVh07dqzoenNzc8XWWrZsWbG1RkdHi601MjJSbK2SDhw4oKNHj3qhrzUS9djYmKanp5v41q2ampoqut6WLVuKrTU5OVlsrbvuuqvYWuedd16xtUqamJj4wq/x9BtIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSKZS1LY32n7L9l7bdzQ9FIDBLRq17RFJv5R0paSLJW22fXHTgwEYTJUj9TpJeyNif0Qcl/SYpGubHQvAoKpEvVLSwVNuz/U+9xm2b7I9bXv6yJEjdc0HoE9Vol7o17v+62qFEXF/RExExMSKFSuGnwzAQKpEPSdp1Sm3RyW908w4AIZVJepXJV1k+0LbZ0jaJOmpZscCMKhFL5IQESds3yzpOUkjkh6MiD2NTwZgIJWufBIRz0h6puFZANSAM8qAZIgaSIaogWSIGkiGqIFkiBpIhqiBZBrZoSOrkjtmSNLbb79dbK2SWwqdf/75xdbasWNHsbUk6brrriu63kI4UgPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kEyVHToetH3Y9uslBgIwnCpH6l9J2tjwHABqsmjUEfFbSf8sMAuAGtT2mpptd4BuqC1qtt0BuoF3v4FkiBpIpsqPtB6V9HtJq23P2f5R82MBGFSVvbQ2lxgEQD14+g0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0ks+S33ZmZmSm2VsltcCRp3759xdYaHx8vttbk5GSxtUr++5DYdgdAA4gaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWSIGkimyjXKVtl+0fas7T22by0xGIDBVDn3+4Skn0bELtvnSJqx/XxEvNHwbAAGUGXbnUMRsav38QeSZiWtbHowAIPp6zW17TFJayS9ssDX2HYH6IDKUds+W9Ljkm6LiPc//3W23QG6oVLUtk/XfNDbIuKJZkcCMIwq735b0gOSZiPinuZHAjCMKkfq9ZJukLTB9u7en+82PBeAAVXZdudlSS4wC4AacEYZkAxRA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8ks+b20Sv5G2Nq1a4utJZXd36qkSy+9tNhaH3/8cbG1uoIjNZAMUQPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQTJULD37Z9h9t/6m37c7PSwwGYDBVThP9t6QNEfFh71LBL9v+dUT8oeHZAAygyoUHQ9KHvZun9/5Ek0MBGFzVi/mP2N4t6bCk5yOCbXeAjqoUdUR8EhGXSBqVtM72Nxa4D9vuAB3Q17vfEfGepJckbWxkGgBDq/Lu9wrb5/Y+/oqkb0t6s+nBAAymyrvfF0h62PaI5v8nsCMinm52LACDqvLu9581vyc1gCWAM8qAZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSGbJb7tz6NChYmtNTk4WWyuzY8eOFVvrxIkTxdbqCo7UQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kUznq3gX9X7PNRQeBDuvnSH2rpNmmBgFQj6rb7oxKukrS1mbHATCsqkfqeyXdLunTL7oDe2kB3VBlh46rJR2OiJn/
dT/20gK6ocqRer2ka2wfkPSYpA22H2l0KgADWzTqiLgzIkYjYkzSJkkvRMQPGp8MwED4OTWQTF+XM4qIlzS/lS2AjuJIDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSSz5LfdWbVqVbG1tm/fXmyt0kpuhTM9PV1sreuvv77YWl3BkRpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWQqnSbau5LoB5I+kXQiIiaaHArA4Po59/tbEXG0sUkA1IKn30AyVaMOSb+xPWP7poXuwLY7QDdUjXp9RKyVdKWkH9u+/PN3YNsdoBsqRR0R7/T+e1jSk5LWNTkUgMFV2SDvLNvnnPxY0nckvd70YAAGU+Xd769JetL2yftvj4hnG50KwMAWjToi9kv6ZoFZANSAH2kByRA1kAxRA8kQNZAMUQPJEDWQDFEDySz5bXfGx8eLrVVyuxhJmpqaSrlWSVu2bGl7hOI4UgPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kEylqG2fa3un7Tdtz9q+rOnBAAym6rnfv5D0bER83/YZks5scCYAQ1g0atvLJF0u6YeSFBHHJR1vdiwAg6ry9Htc0hFJD9l+zfbW3vW/P4Ntd4BuqBL1aZLWSrovItZI+kjSHZ+/E9vuAN1QJeo5SXMR8Urv9k7NRw6ggxaNOiLelXTQ9urep66Q9EajUwEYWNV3v2+RtK33zvd+STc2NxKAYVSKOiJ2S5poeBYANeCMMiAZogaSIWogGaIGkiFqIBmiBpIhaiAZogaSYS+tPtx9993F1pLK7gM1MVHu3KKZmZlia/0/4kgNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSSzaNS2V9vefcqf923fVmI4AP1b9DTRiHhL0iWSZHtE0t8lPdnwXAAG1O/T7ysk7YuIvzUxDIDh9Rv1JkmPLvQFtt0BuqFy1L1rfl8jaWqhr7PtDtAN/Rypr5S0KyL+0dQwAIbXT9Sb9QVPvQF0R6WobZ8paVLSE82OA2BYVbfd+ZekrzY8C4AacEYZkAxRA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8k4Iur/pvYRSf3+euZySUdrH6Ybsj42Hld7vh4RC/7mVCNRD8L2dESU29CpoKyPjcfVTTz9BpIhaiCZLkV9f9sDNCjrY+NxdVBnXlMDqEeXjtQAakDUQDKdiNr2Rttv2d5r+46256mD7VW2X7Q9a3uP7VvbnqlOtkdsv2b76bZnqZPtc23vtP1m7+/usrZn6lfrr6l7GwT8VfOXS5qT9KqkzRHxRquDDcn2BZIuiIhdts+RNCPpe0v9cZ1k+yeSJiQti4ir256nLrYflvS7iNjau4LumRHxXttz9aMLR+p1kvZGxP6IOC7pMUnXtjzT0CLiUETs6n38gaRZSSvbnaoetkclXSVpa9uz1Mn2MkmXS3pAkiLi+FILWupG1CslHTzl9pyS/OM/yfaYpDWSXml3ktrcK+l2SZ+2PUjNxiUdkfRQ76XFVttntT1Uv7oQtRf4XJqfs9k+W9Ljkm6LiPfbnmdYtq+WdDgiZtqepQGnSVor6b6IWCPpI0lL7j2eLkQ9J2nVKbdHJb3T0iy1sn265oPeFhFZLq+8XtI1tg9o/qXSBtuPtDtSbeYkzUXEyWdUOzUf+ZLShahflXSR7Qt7b0xskvRUyzMNzbY1/9psNiLuaXueukTEnRExGhFjmv+7eiEiftDyWLWIiHclHbS9uvepKyQtuTc2K133u0kRccL2zZKekzQi6cGI2NPyWHVYL+kGSX+xvbv3uZ9FxDMtzoTF3SJpW+8As1/SjS3P07fWf6QFoF5dePoNoEZEDSRD1EAyRA0kQ9RAMkQNJEPUQDL/AROstwgyTzaqAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Experiment parameter: pick any sample index, then classify it with the trained MLP\n",
    "# and compare the prediction with the true label\n",
    "new_idx = 0\n",
    "# guard against an out-of-range index (the original checked an undefined name, `my_idx`)\n",
    "assert new_idx in range(y.shape[0])\n",
    "X_new = X[new_idx]\n",
    "y_new = y[new_idx]\n",
    "# the true class is the position of the 1 in the one-hot label\n",
    "print('selected image category is %d' % np.where(y_new==1)[0][0])\n",
    "plt.imshow(X_new.reshape(8,8), cmap=plt.cm.binary, interpolation='nearest')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MLP predicted category is 0\n"
     ]
    }
   ],
   "source": [
    "# Classify the sample chosen in the previous cell; predict() expects one sample per column,\n",
    "# so the flat feature vector is reshaped into a single column\n",
    "y_predict = mlp3.predict(X_new.reshape(-1,1))\n",
    "# predict() returns one class index per column; take the single entry before formatting\n",
    "print('MLP predicted category is %d' % y_predict[0])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
